import numpy as np
import pandas as pd
from scipy.special import gammainc
from scipy.optimize import newton, fsolve
import os
import matplotlib.pyplot as plt
import support_functions as sp
import tikzplotlib

def output_solver(x, A, sigma, superelast):
	"""Residual function handed to the root finder for firm outputs.

	Given a candidate output vector ``x`` and productivities ``A``, the
	candidate is rescaled with ``normalize_x``, prices are formed as
	markup divided by productivity, two price aggregates are computed, and
	the outputs implied by inverting ``upsilon_prime_kw`` at the relative
	prices are compared with the candidate.

	Returns the element-wise difference ``implied_output - x``; a root of
	this function is a candidate that reproduces itself.
	"""
	x = np.array(x)
	x_norm = normalize_x(x, sigma, superelast)
	n_firms = len(x_norm)

	# Price of each firm: markup over marginal cost 1/A.
	prices = markup_kw(x_norm, sigma, superelast) / np.array(A)

	# Average revenue per firm, and the average of Upsilon'(x) * x used to
	# convert it into the index that deflates individual prices.
	P_Y = np.sum(prices * x_norm) / n_firms
	demand_index = np.sum(upsilon_prime_kw(x_norm, sigma, superelast) * x_norm) / n_firms
	P = P_Y / demand_index

	implied_output = upsilon_prime_inv_kw(prices / P, sigma, superelast)
	return implied_output - x

def normalize_x(x, sigma, superelast):
	"""Rescale ``x`` so that the mean of ``upsilon_kw`` over firms equals 1.

	A scalar factor ``c`` is found with ``scipy.optimize.newton`` (started
	at 1, no derivative supplied) such that
	``mean(upsilon_kw(x / c, sigma, superelast)) == 1``.

	Returns ``x / c`` as a NumPy array. ``newton`` raises ``RuntimeError``
	if it fails to converge.
	"""
	x = np.asarray(x, dtype=float)
	n_firms = len(x)

	def normalization_gap(scale):
		# Zero when the average of Upsilon over the rescaled outputs is 1.
		return 1 - np.sum(upsilon_kw(x / scale, sigma, superelast)) / n_firms

	# No debug print here: this runs on every residual evaluation of the
	# outer solver, so printing the factor each time floods stdout.
	scale = newton(normalization_gap, 1)
	return x / scale

def upsilon_kw(x, sigma, superelast):
	"""Aggregator Upsilon(x) whose derivative is ``upsilon_prime_kw``.

	Upsilon(x) = 1 + (sigma - 1) * exp(1/eps) * eps**(sigma/eps - 1)
	             * [G(sigma/eps, 1/eps) - G(sigma/eps, x**(eps/sigma)/eps)]

	where ``eps`` is ``superelast`` and ``G(a, z)`` is the *unregularized
	upper* incomplete gamma function. SciPy only exposes regularized
	versions, so ``G(a, z) = gamma(a) * gammaincc(a, z)``. Using the
	regularized lower function ``gammainc`` in the same positions flips the
	sign and drops the ``gamma(a)`` factor, which makes the function
	decreasing in ``x`` and inconsistent with ``upsilon_prime_kw``.

	Satisfies Upsilon(1) = 1. Accepts scalars or arrays for ``x``.
	"""
	# Local import keeps this block self-contained; the module header only
	# imports ``gammainc``.
	from scipy.special import gamma, gammaincc

	shape = sigma / superelast
	upper_at_one = gammaincc(shape, 1/superelast)
	upper_at_x = gammaincc(shape, x**(superelast/sigma)/superelast)
	# Difference of unregularized upper incomplete gammas.
	gamma_diff = gamma(shape) * (upper_at_one - upper_at_x)
	return 1 + (sigma - 1) * np.exp(1/superelast) * superelast**(shape - 1) * gamma_diff

def upsilon_prime_kw(x, sigma, superelast):
	"""Evaluate (sigma-1)/sigma * exp((1 - x**(superelast/sigma)) / superelast).

	Works element-wise on arrays; equals (sigma-1)/sigma at ``x == 1``.
	"""
	exponent = (1 - x**(superelast/sigma))/superelast
	return (sigma-1)/sigma * np.exp(exponent)

def upsilon_prime_inv_kw(x, sigma, superelast):
	"""Invert ``upsilon_prime_kw``: return the quantity at which it equals ``x``.

	Computes ``max(1 + superelast * log((sigma-1)/(sigma*x)), 0) ** (sigma/superelast)``
	element-wise. The clamp at zero means any ``x`` large enough to make the
	bracket negative maps to a quantity of 0.

	Implemented with native NumPy operations rather than ``np.vectorize``:
	this function is evaluated on the full firm vector at every solver
	iteration, where a per-element Python loop is needlessly slow, and
	``np.vectorize`` also raises on size-0 input. Scalars, lists and arrays
	are accepted; arguments broadcast against each other.
	"""
	x = np.asarray(x, dtype=float)
	temp = 1 + superelast * np.log((sigma - 1)/(sigma * x))
	# Clamp before exponentiating so a negative base never reaches the
	# fractional power.
	return np.maximum(temp, 0)**(sigma/superelast)

def markup_kw(x, sigma, superelast):
	"""Markup implied by the demand elasticity: 1 / (1 - 1/elasticity)."""
	inverse_elasticity = 1/elasticity_kw(x, sigma, superelast)
	return 1/(1 - inverse_elasticity)

def elasticity_kw(x, sigma, superelast):
	"""Demand elasticity sigma * x**(-superelast/sigma); equals sigma at x == 1."""
	exponent = -superelast/sigma
	return sigma * x**exponent

def passthrough_kw(x, sigma, superelast):
	"""Pass-through 1 / (1 + dlog(markup)/dlog(x) * elasticity).

	The log-derivative of the markup with respect to ``x`` is written out
	in closed form from the expressions in ``markup_kw`` / ``elasticity_kw``.
	"""
	# x**(superelast/sigma) appears in both numerator and denominator.
	scaled = x**(superelast/sigma)
	dlogmu_dlogx = superelast/(sigma**2) * scaled / (1 - 1/sigma * scaled)
	elasticity = elasticity_kw(x, sigma, superelast)
	return 1/(1 + dlogmu_dlogx*elasticity)

def shares_kw(x, A, sigma, superelast):
	"""Each firm's revenue (price * x) as a fraction of total revenue.

	Price is the markup from ``markup_kw`` times ``1/A``.
	"""
	revenue = markup_kw(x, sigma, superelast) * (1/A) * x
	total_revenue = np.sum(revenue)
	return revenue/total_revenue

def draw_firm_productivities(N_firms, pareto_scale=8):
	"""Draw ``N_firms`` productivities and return them sorted ascending.

	Amiti et al. version: Pareto with shape = 8 and lower bound = exp(0).
	Despite its name, ``pareto_scale`` is passed as the first (shape)
	argument of ``np.random.pareto``; the ``+ 1`` shifts the draws so the
	smallest possible value is 1. Uses NumPy's global random state.
	"""
	draws = np.random.pareto(pareto_scale, N_firms) + 1
	return np.sort(draws)

def generate_kw_distributions(N_firms=1000, sigma=5, superelast=1.6, seed=True):
	"""Solve for firm outputs and return ``(shares, markups, passthroughs)``.

	Results are cached as a raw float array under
	``data/klenow_willis_store/``. The stored array has five rows:
	shares, markups, passthroughs, productivities ``A`` and solved outputs
	``x_res``; only the first three are returned.

	Parameters:
		N_firms: number of firms to draw.
		sigma, superelast: parameters forwarded to the demand functions.
		seed: when truthy, NumPy's global RNG is seeded with 23 before
			drawing productivities.

	NOTE: ``seed`` is not part of the cache key, so a cached result is
	returned regardless of its value.
	"""
	cache_dir = 'data/klenow_willis_store/'
	cache_key = str(superelast) + '_' + str(N_firms)
	if sigma != 5:
		# The key must distinguish sigma, otherwise a file computed for one
		# sigma is silently returned for another. The default sigma keeps
		# the historical file name so existing cache files remain usable.
		cache_key += '_sigma' + str(sigma)
	filename = cache_dir + cache_key + '.store'

	if os.path.exists(filename):
		# Five equally long rows were written; let NumPy infer the row length.
		saved = np.fromfile(filename).reshape(5, -1)
		print(saved.shape)
		return saved[0,:],saved[1,:],saved[2,:]

	if seed:
		np.random.seed(23)
	A = draw_firm_productivities(N_firms)
	# Productivities double as the initial guess for outputs.
	x_res = fsolve(output_solver, A, args=(A, sigma, superelast))
	markups = markup_kw(x_res, sigma, superelast)
	passthroughs = passthrough_kw(x_res, sigma, superelast)
	shares = shares_kw(x_res, A, sigma, superelast)

	# Plain 2-D array (np.matrix is discouraged); same bytes on disk.
	to_save = np.vstack([shares, markups, passthroughs, A, x_res])
	print(to_save.shape)
	# Make sure the cache directory exists before writing.
	os.makedirs(cache_dir, exist_ok=True)
	to_save.tofile(filename)
	return shares,markups,passthroughs

def plot_kw_distributions(N_firms=1000, sigma=5, superelast=1.6, seed=True, savefileroot=None):
	"""Plot the shares, markups and pass-throughs from ``generate_kw_distributions``.

	Creates three figures: a stacked three-panel overview, pass-through
	against an evenly spaced [0, 1] grid, and log shares against the same
	grid. When ``savefileroot`` is truthy the last two figures are exported
	through tikzplotlib to ``savefileroot + '_kw_passthroughs.tex'`` and
	``savefileroot + '_kw_salesdensity.tex'``. Nothing is returned and
	``plt.show()`` is not called.
	"""
	shares, markups, passthroughs = generate_kw_distributions(N_firms, sigma, superelast, seed)

	# Overview figure: one panel per series, in the order returned.
	_, axs = plt.subplots(3, 1, figsize=(4, 8))
	panels = (
		(shares, '$\\lambda_\\theta$'),
		(markups, '$\\mu_\\theta$'),
		(passthroughs, '$\\rho_\\theta$'),
	)
	for ax, (series, ylabel) in zip(axs, panels):
		ax.plot(series)
		ax.set_ylabel(ylabel)
	bottom = axs[2]
	bottom.set_ylim([0, 1])
	bottom.set_xlabel('Firms ordered by size')
	plt.tight_layout()

	def export_current_figure(suffix):
		# Clear any super-title and simplify the figure before the TikZ export.
		plt.suptitle(None)
		tikzplotlib.clean_figure()
		tikzplotlib.save(savefileroot + suffix, extra_axis_parameters=['PlotStyle'])

	# Pass-through figure.
	plt.figure()
	firm_grid = np.linspace(0, 1, N_firms)
	plt.plot(firm_grid, passthroughs)
	plt.ylim([0.2, 1])
	if savefileroot:
		export_current_figure('_kw_passthroughs.tex')

	# Log-share figure.
	plt.figure()
	plt.plot(firm_grid, np.log(shares))
	if savefileroot:
		export_current_figure('_kw_salesdensity.tex')


def plot_kw_markups():
	"""Compare markups implied by pass-throughs under two approaches.

	Loads shares and pass-throughs from ``sp.get_key_darwinian_data()``,
	builds one markup series with ``sp.generate_markups`` and another from
	the closed form ``(1/passthrough - 1) / eps_sigma``, then plots both.
	Values of the closed-form series below 1 are replaced by NaN so they
	are not drawn. Nothing is returned and ``plt.show()`` is not called.
	"""
	# Only shares and pass-throughs are used from the loaded data; markups
	# are regenerated below.
	shares, _, passthroughs, _, _ = sp.get_key_darwinian_data()
	markups = sp.generate_markups(passthroughs, shares, 1.15)

	# eps_sigma is backed out from a pass-through of 0.97 at a markup of 1.
	AMITI_SMALL_FIRMS = 0.97
	eps_sigma = 1/AMITI_SMALL_FIRMS - 1
	mu_kw = (1/passthroughs - 1)/eps_sigma
	mu_kw[mu_kw < 1] = np.nan

	# NOTE(review): the next two groups draw on the same implicit axes (no
	# plt.figure() in between), mixing (grid, markup) and (markup,
	# pass-through) lines. Kept as-is; confirm whether separate figures
	# were intended.
	firm_grid = np.linspace(0, 1, len(shares))
	plt.plot(firm_grid, mu_kw, label='Klenow and Willis (2016)')
	plt.plot(firm_grid, markups, label='Nonparametric approach')
	plt.legend()
	plt.yscale('log')
	plt.plot(mu_kw, passthroughs, label='Klenow and Willis (2016)')
	plt.plot(markups, passthroughs, label='Nonparametric approach')
	plt.legend()
	plt.xscale('log')

	# Side-by-side comparison: markup against pass-through, one panel each.
	_, axs = plt.subplots(1, 2, figsize=(8, 4))
	axs[0].plot(passthroughs, mu_kw, label='Klenow and Willis (2016)')
	axs[1].plot(passthroughs, markups, label='Nonparametric approach')
	# Reverse the x-axis so pass-through decreases from left to right.
	for ax in axs:
		lo, hi = ax.get_xlim()
		ax.set_xlim([hi, lo])
	axs[0].set_title('Klenow and Willis (2016) preferences')
	axs[1].set_title('Nonparametric approach')
	for ax in axs:
		ax.set_xlabel('Passthrough $(\\rho_\\theta)$')
	axs[0].set_ylabel('Implied markup $(\\mu_\\theta)$')
	for ax in axs:
		ax.set_yscale('log')
	# Share the left panel's y-range so the two panels are comparable.
	axs[1].set_ylim(axs[0].get_ylim())
	plt.tight_layout()



#plot_kw_distributions(N_firms=5000, superelast=1.6, savefileroot='../Draft/figures/klenow_willis')
#plt.show()
# shares,markups,passthroughs = generate_kw_distributions(N_firms=5000, superelast=10)
# print(markups, passthroughs)
